package org.zjvis.datascience.common.graph.algo;

import java.util.*;
import java.util.stream.Collectors;

/**
 * Single-source shortest distances/paths over a weighted graph using Dijkstra's algorithm.
 *
 * <p>The graph is supplied as an adjacency map: each key is a node and its value lists the
 * {@link EdgeWrapper}s leaving that node; an edge's {@code target} is the neighbouring node and
 * its {@code weight} is the cost of traversing it. Edge weights are assumed to be non-negative —
 * Dijkstra's algorithm does not produce correct results with negative weights.
 *
 * <p>Instances hold mutable state and perform no synchronization. Each call to
 * {@link #execute(Object)} resets and recomputes that state.
 *
 * @description DijkstraShortestPath
 * @date 2021-12-29
 */
public class DijkstraShortestPath {
    /** Sentinel distance for nodes that have not been reached from the source. */
    private static final Double INF = Double.MAX_VALUE;

    private Object source;
    /** Node -> shortest known distance from {@link #source}; unreached nodes map to {@link #INF}. */
    private Map<Object, Double> distances;
    /** Largest finite, un-normalized distance produced by the most recent run. */
    private Double maxDistance = 0.0D;

    /** Adjacency map: node -> edges leaving that node. */
    private Map<Object, List<EdgeWrapper>> neighborMap;

    /** Node -> edge through which its current shortest distance was reached. */
    private Map<Object, EdgeWrapper> predecessors;

    // NOTE(review): isDirected is stored but never read by this class; directedness is
    // presumably already encoded in neighborMap by the caller — confirm.
    private Boolean isDirected;
    /** When true, unreached nodes are dropped and finite distances are divided by maxDistance. */
    private Boolean isNormalized;


    /**
     * Creates a shortest-path solver; no work is done until {@link #execute(Object)} is called.
     *
     * @param source       node from which distances are measured
     * @param neighborMap  adjacency map: node -> edges leaving that node
     * @param isDirected   stored but currently unused by this class
     * @param isNormalized whether to normalize distances by the maximum distance after a run
     */
    public DijkstraShortestPath(Object source, Map<Object, List<EdgeWrapper>> neighborMap, Boolean isDirected, Boolean isNormalized) {
        this.source = source;
        this.neighborMap = neighborMap;
        this.isDirected = isDirected;
        this.isNormalized = isNormalized;
        this.distances = new HashMap<>();
        this.predecessors = new HashMap<>();
    }

    /** Computes shortest distances from the source to every reachable node. */
    public void execute() {
        execute(null);
    }

    /**
     * Runs Dijkstra's algorithm from the source.
     *
     * <p>When {@code target} is non-null the search stops as soon as the target is settled (its
     * distance is final). Nodes on the frontier at that moment keep their tentative distances.
     * The max distance and, if enabled, normalization are applied on every run, including
     * early-terminated ones.
     *
     * @param target optional node at which to stop early; {@code null} explores the whole graph
     */
    public void execute(Object target) {
        // Reset state so repeated runs never see results from a previous call.
        distances.clear();
        predecessors.clear();
        maxDistance = 0.0D;

        for (Object node : neighborMap.keySet()) {
            distances.put(node, INF);
        }
        distances.put(source, 0.0D);

        PriorityQueue<Map.Entry<Object, Double>> queue =
                new PriorityQueue<>(Map.Entry.<Object, Double>comparingByValue());
        queue.add(new AbstractMap.SimpleImmutableEntry<>(source, 0.0D));
        Set<Object> settledNodes = new HashSet<>();

        while (!queue.isEmpty()) {
            Object node = queue.poll().getKey();
            // Lazy deletion: a node may be queued several times; only its first (smallest) pop counts.
            if (!settledNodes.add(node)) {
                continue;
            }
            // A distance is only final once the node is settled, so stop here rather than when
            // the target is first relaxed (which may not be via its shortest path).
            if (target != null && target.equals(node)) {
                break;
            }
            relaxNeighbors(node, settledNodes, queue);
        }

        maxDistance = computeMaxDistance();
        if (Boolean.TRUE.equals(isNormalized)) {
            calculateCorrection();
        }
    }

    /** Relaxes every edge leaving {@code node}, queueing neighbours whose distance improved. */
    private void relaxNeighbors(Object node, Set<Object> settledNodes,
                                PriorityQueue<Map.Entry<Object, Double>> queue) {
        List<EdgeWrapper> edges = neighborMap.get(node);
        if (edges == null) {
            // Nodes that only appear as edge targets have no adjacency entry: nothing to relax.
            return;
        }
        double base = getShortestDistance(node);
        for (EdgeWrapper edge : edges) {
            Object nbr = edge.target;
            if (settledNodes.contains(nbr)) {
                continue;
            }
            double candidate = base + edge.weight;
            if (candidate < getShortestDistance(nbr)) {
                distances.put(nbr, candidate);
                predecessors.put(nbr, edge);
                queue.add(new AbstractMap.SimpleImmutableEntry<>(nbr, candidate));
            }
        }
    }

    /** Returns the largest finite value in {@link #distances}, or 0 if there is none. */
    private double computeMaxDistance() {
        double max = 0.0D;
        for (Double d : distances.values()) {
            if (!INF.equals(d) && d > max) {
                max = d;
            }
        }
        return max;
    }

    /**
     * Normalizes distances in place: drops unreached ({@link #INF}) nodes and divides the rest
     * by {@link #maxDistance}, leaving them unchanged when the max distance is 0.
     */
    private void calculateCorrection() {
        distances.values().removeIf(INF::equals);
        final double max = maxDistance;
        if (max != 0.0D) {
            distances.replaceAll((node, d) -> d / max);
        }
    }

    /** Returns the recorded distance to {@code target}, or {@link #INF} if none is recorded. */
    private Double getShortestDistance(Object target) {
        Double d = distances.get(target);
        return d == null ? INF : d;
    }

    /** Returns the live distance map computed by the last run (normalized if enabled). */
    public Map<Object, Double> getDistances() {
        return distances;
    }

    /** Returns the largest finite, un-normalized distance computed by the last run. */
    public Double getMaxDistance() {
        return maxDistance;
    }

    /** Returns the live node -> predecessor-edge map computed by the last run. */
    public Map<Object, EdgeWrapper> getPredecessors() {
        return predecessors;
    }

    /**
     * Reconstructs the path from the source to {@code target} by following predecessor edges.
     *
     * @param target destination node
     * @return nodes from source to target inclusive, or an empty list if {@code target} has no
     *         recorded predecessor (unreached, or the source itself) or the chain is broken
     */
    public List<Object> getPath(Object target) {
        List<Object> path = new ArrayList<>();
        if (predecessors.get(target) == null) {
            return path;
        }
        Object current = target;
        path.add(current);
        while (!Objects.equals(current, source)) {
            EdgeWrapper edge = predecessors.get(current);
            if (edge == null) {
                // NOTE(review): only reachable if an edge's `source` is not the node it was
                // relaxed from; treat as unreachable instead of throwing an NPE.
                return new ArrayList<>();
            }
            current = edge.source;
            path.add(current);
        }
        Collections.reverse(path);
        return path;
    }
}
